/**
 * 
 */
package edu.umd.clip.lm.ngram;

import java.io.IOException;
import java.util.*;

import edu.umd.clip.lm.model.*;
import edu.umd.clip.lm.model.data.*;

/**
 * Generates synthetic training data by sampling from a chain of n-gram models.
 * <p>
 * The models are visited in array order. At each model the sample budget is
 * distributed over that model's future vocabulary by repeatedly drawing from
 * the model's probabilities; every vocabulary entry that received at least one
 * draw is then expanded with the next model in the chain. The last model in
 * the chain emits the accumulated (tuple, count) pairs for the current context.
 *
 * @author Denis Filimonov <den@cs.umd.edu>
 *
 */
public class TrainingDataSampler {
	private final NgramModel[] models;
	private final int order;
	/** Single random source reused by every recursion level instead of one per call. */
	private final Random rnd = new Random();

	/**
	 * Creates a sampler over the given model chain. The context order is taken
	 * from the language model of the current {@code Experiment} instance.
	 *
	 * @param models the models to sample from, in the order they are applied
	 */
	public TrainingDataSampler(NgramModel models[]) {
		this.models = models;
		
		Experiment exp = Experiment.getInstance();
		this.order = exp.getLM().getOrder();
	}
	
	/**
	 * Draws {@code sampleSize} samples starting from a freshly constructed
	 * {@code Context} of the LM order and a zero future tuple, writing the
	 * resulting blocks to {@code data}. A final block that still holds data
	 * is written before returning.
	 *
	 * @param sampleSize number of draws to make from the first model
	 * @param data sink receiving the generated training data blocks
	 * @throws IOException propagated from the underlying data sink or models
	 */
	public void doSampling(long sampleSize, WritableTrainingData data) throws IOException {
		Context context = new Context(order);
		TrainingDataBlock block = new TrainingDataBlock();
		block = doSampling(sampleSize, data, block, 0, context, 0);
		if (block.hasData()) {
			data.add(block);
		}
	}
	
	/**
	 * Recursive sampling step for the model at index {@code currentModel}.
	 *
	 * @param sampleSize number of draws to distribute over this model's vocabulary
	 * @param data sink that receives blocks once they refuse further pairs
	 * @param block block currently being filled
	 * @param currentModel index into {@code models} of the model to sample from
	 * @param currentContext context shared by all levels; it is modified in place
	 *        when the model's future variable has a negative offset
	 * @param futureTuple tuple built so far by the preceding models
	 * @return the block to continue filling (either {@code block} or a new one)
	 * @throws IOException propagated from the data sink or models
	 */
	private TrainingDataBlock doSampling(long sampleSize, WritableTrainingData data, TrainingDataBlock block, 
			int currentModel, Context currentContext, long futureTuple) 
	throws IOException 
	{
		NgramModel model = models[currentModel];
		int[] ctx = extractModelContext(model, currentContext, futureTuple);

		int[] vocab = model.getFutureVocab(ctx);
		if (vocab.length == 0) {
			// Nothing can be drawn for this context. Without this guard the
			// index computed from the binary search would be -1 and the
			// count update would throw ArrayIndexOutOfBoundsException.
			return block;
		}
		
		long[] counts = sampleCounts(model, vocab, ctx, sampleSize);
		
		// only the last model in the chain collects and emits futures
		final ArrayList<TupleCountPair> futures = currentModel == models.length - 1 ? new ArrayList<TupleCountPair>() : null;
		
		for(int i=0; i<counts.length; ++i) {
			long count = counts[i];
			if (count == 0) continue;
			
			int word = vocab[i];
			if (model.getFutureVariable().getOffset() < 0) {
				// the sampled value is stored in the context itself
				model.getFutureVariable().setValue(currentContext.data, word);
			} else {
				// the sampled value is stored in the future tuple; later
				// iterations overwrite the same variable with their own word
				futureTuple = model.getFutureVariable().setValue(futureTuple, word); 
			}
			
			if (futures != null) {
				// done sampling, output the data
				// NOTE(review): narrowing cast; a count above Integer.MAX_VALUE
				// would wrap around - presumably sample sizes stay below that, confirm
				futures.add(new TupleCountPair(futureTuple, (int) count));
			} else {
				block = doSampling(count, data, block, currentModel+1, currentContext, futureTuple);
			}
		}
		
		if (futures != null) {
			// the context keeps being modified by the callers, so store a copy
			Context contextCopy = (Context) currentContext.clone();
			ContextFuturesPair pair = new ContextFuturesPair(contextCopy, futures.toArray(new TupleCountPair[futures.size()]));
			if (!block.add(pair)) {
				// the block rejected the pair: write it out and start a new one
				data.add(block);
				block = new TrainingDataBlock();
				block.add(pair);
			}
		}
		return block;
	}
	
	/**
	 * Reads the values of the model's context variables: variables with offset 0
	 * are read from the future tuple, all others from the context data.
	 */
	private int[] extractModelContext(NgramModel model, Context currentContext, long futureTuple) 
	throws IOException 
	{
		CtxVar[] contextVariables = model.getContextVariables();
		int[] ctx = new int[contextVariables.length];
		
		for(int i=0; i<ctx.length; ++i) {
			if (contextVariables[i].getOffset() == 0) {
				ctx[i] = contextVariables[i].getValue(futureTuple);
			} else {
				ctx[i] = contextVariables[i].getValue(currentContext.data);
			}
		}
		return ctx;
	}
	
	/**
	 * Makes {@code sampleSize} draws from the model's probabilities over
	 * {@code vocab} and returns how many draws each vocabulary entry received.
	 * {@code vocab} must not be empty.
	 */
	private long[] sampleCounts(NgramModel model, int[] vocab, int[] ctx, long sampleSize) 
	throws IOException 
	{
		double[] cumulativeDistribution = new double[vocab.length];
		
		double sum = 0;
		for(int i=0; i<vocab.length; ++i) {
			double prob = model.getProb(vocab[i], ctx);
			sum += prob;
			cumulativeDistribution[i] = sum;
		}
		
		long[] counts = new long[vocab.length];
		// TODO: an expensive way to generate counts!
		for(long i=0; i<sampleSize; ++i) {
			// sum should be very close to 1.0
			double prob = rnd.nextDouble() * sum;
			int idx = Arrays.binarySearch(cumulativeDistribution, prob);
			// a negative result encodes the insertion point as -(point) - 1
			if (idx < 0) idx = - idx - 1;
			// rounding may push the draw past the last entry; clamp to it
			if (idx == counts.length) --idx;
			counts[idx]++;
		}
		return counts;
	}

}
